import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import numpy as np


class STRAttnGraphConv(nn.Module):
    """Graph convolution whose adjacency is computed from the input itself.

    For every sample, the flattened history of each node (``t * k`` values) is
    projected by ``W_R1`` and ``W_R2`` into ``d_c``-dim factors ``r1``/``r2``.
    The adjacency is ``softmax(relu(r1 @ r2^T), dim=2)``. It is applied for
    ``kt`` propagation rounds, each adding a fresh linear projection of the
    input.

    ``n_route`` and ``n_head`` are unused. ``dis_mat``, ``d_q`` and ``d_k`` are
    only stored. All of them are accepted for signature compatibility with the
    sibling layers.
    """

    def __init__(self, n_route, n_his, d_attribute, d_out, dis_mat, n_head=4, d_q=32, d_k=64, d_c=10, kt=2, normal=False) -> None:
        super(STRAttnGraphConv, self).__init__()
        # Two independent projections, so r1 @ r2^T need not be symmetric.
        self.W_R1 = nn.Parameter(torch.empty(d_attribute*n_his, d_c))
        self.W_R2 = nn.Parameter(torch.empty(d_attribute*n_his, d_c))
        nn.init.xavier_uniform_(self.W_R1.data)
        nn.init.xavier_uniform_(self.W_R2.data)

        self.d_q = d_q
        self.d_k = d_k
        self.d_out = d_out
        self.d_attribute = d_attribute
        self.w_2 = nn.Linear(d_attribute, d_out, bias=False)
        self.w_stack = nn.ModuleList([
            nn.Linear(d_attribute, d_out, bias=False) for _ in range(kt)
        ])

        self.distant_mat = dis_mat  # stored only; not used by forward

        self.no = normal
        if self.no:
            # Normalises (z + x), so d_out must equal d_attribute.
            self.norm = nn.LayerNorm(d_attribute, eps=1e-6)

        self.ets = None

    def forward(self, x):
        """Apply the layer.

        Args:
            x: tensor (b, n, t, k); ``t * k`` must equal ``n_his * d_attribute``
                (the number of rows of ``W_R1``/``W_R2``).

        Returns:
            Tensor (b, n, t, d_out). When ``normal`` is set, returns
            ``LayerNorm(z + x)`` instead.
        """
        residual = x
        b, n, t, k = x.size()

        # Project each node's flattened history two independent ways.
        flat = residual.reshape(b, n, k*t)
        r1 = flat @ self.W_R1
        r2 = flat @ self.W_R2  # second factor uses W_R2, not W_R1
        res_graph = torch.matmul(r1, r2.transpose(1, 2))  # (b, n, n)

        # Row-normalised adjacency; the extra axis broadcasts over time steps.
        adj = torch.softmax(torch.relu(res_graph), dim=2).unsqueeze(1)

        x = x.permute(0, 2, 1, 3)  # (b, t, n, k)
        tokens = x.reshape(-1, k)
        z = self.w_2(tokens).reshape(-1, t, n, self.d_out)

        for w in self.w_stack:
            z = torch.matmul(adj, z) + w(tokens).reshape(-1, t, n, self.d_out)
        z = z.permute(0, 2, 1, 3).contiguous()

        if self.no:
            z = z + residual
            z = self.norm(z)

        return z


class STGAttnGraphConv(nn.Module):
    """Graph convolution over a fully learned, input-independent adjacency.

    The adjacency is ``softmax(relu(V_S @ V_T^T), dim=1)``, built from two
    learned node-factor matrices of shape (n_route, d_c). It is applied for
    ``kt`` propagation rounds, each adding a fresh linear projection of the
    input.
    """

    def __init__(self, n_route, n_his, d_attribute, d_out, dis_mat, n_head=4, d_q=32, d_k=64, d_c=10, kt=2, normal=False) -> None:
        super().__init__()
        # Node factors, created and Xavier-initialised in this exact order.
        for name in ("V_S", "V_T"):
            factor = nn.Parameter(torch.empty(n_route, d_c))
            nn.init.xavier_uniform_(factor.data)
            setattr(self, name, factor)

        self.d_q, self.d_k = d_q, d_k
        self.d_out, self.d_attribute = d_out, d_attribute
        self.w_2 = nn.Linear(d_attribute, d_out, bias=False)
        self.w_stack = nn.ModuleList(
            [nn.Linear(d_attribute, d_out, bias=False) for _ in range(kt)]
        )

        self.distant_mat = dis_mat  # stored only; not used by forward

        self.no = normal
        if normal:
            # Normalises (z + x), so d_out must equal d_attribute.
            self.norm = nn.LayerNorm(d_attribute, eps=1e-6)

        self.ets = None
        self.d_c = d_c

    def forward(self, x):
        """Apply the layer to ``x`` of shape (b, n, t, k).

        Returns:
            Tensor (b, n, t, d_out), or ``LayerNorm(z + x)`` when ``normal``.
        """
        _, nodes, steps, feat = x.size()

        graph = torch.relu(self.V_S @ self.V_T.t()).softmax(dim=1)  # (n, n)

        time_major = x.permute(0, 2, 1, 3)  # (b, t, n, k)
        tokens = time_major.reshape(-1, feat)
        out = self.w_2(tokens).reshape(-1, steps, nodes, self.d_out)
        for proj in self.w_stack:
            out = graph @ out + proj(tokens).reshape(-1, steps, nodes, self.d_out)
        out = out.permute(0, 2, 1, 3).contiguous()

        if not self.no:
            return out
        return self.norm(out + x)


class STAttnGraphConv(nn.Module):
    """Graph convolution over the softmax-normalised distance matrix ("L" variant).

    The adjacency is ``softmax(relu(dis_drop(dis_mat)), dim=1)``. It is applied
    for ``kt`` propagation rounds, each adding a fresh linear projection of the
    input.

    The attention parameters (``Q_0``, ``Q_S``, ``Q_T``, ``K_S``, ``K_T``,
    ``V_S``, ``V_T``) are still constructed so that parameter names,
    initialisation order and state_dict keys stay unchanged. This variant's
    ``forward`` does not use them.
    """

    def __init__(self, n_route, n_his, d_attribute, d_out, dis_mat, n_head=4, d_q=32, d_k=64, d_c=10, kt=2, normal=False) -> None:
        super(STAttnGraphConv, self).__init__()
        self.K_S = nn.Parameter(torch.empty(n_head, d_k))
        nn.init.xavier_uniform_(self.K_S.data)
        self.V_S = nn.Parameter(torch.empty(n_head, n_route, d_c))
        nn.init.xavier_uniform_(self.V_S.data)
        self.K_T = nn.Parameter(torch.empty(n_head, d_k))
        nn.init.xavier_uniform_(self.K_T.data)
        self.V_T = nn.Parameter(torch.empty(n_head, n_route, d_c))
        nn.init.xavier_uniform_(self.V_T.data)

        self.Q_0 = nn.Linear(d_attribute, d_q, bias=False)
        self.Q_S = nn.Linear(n_route*d_q, d_k, bias=False)
        self.Q_T = nn.Linear(n_his*d_q, d_k, bias=False)

        self.d_q = d_q
        self.d_k = d_k
        self.d_c = d_c
        self.n_route = n_route
        self.d_out = d_out
        self.d_attribute = d_attribute
        self.w_2 = nn.Linear(d_attribute, d_out, bias=False)
        self.w_stack = nn.ModuleList([
            nn.Linear(d_attribute, d_out, bias=False) for _ in range(kt)
        ])

        self.dis_drop = nn.Dropout(0.0)
        self.distant_mat = dis_mat

        self.no = normal
        if self.no:
            # Normalises (z + x), so d_out must equal d_attribute.
            self.norm = nn.LayerNorm(d_attribute, eps=1e-6)

    def forward(self, x):
        """Apply the layer.

        Args:
            x: tensor (b, n, t, k) with ``k == d_attribute`` and ``n`` equal to
                the size of ``dis_mat``.

        Returns:
            Tensor (b, n, t, d_out), or ``LayerNorm(z + x)`` when ``normal``.
        """
        b, n, t, k = x.size()

        # Keep the dropped-out matrix local. Writing it back to
        # self.distant_mat would compound the dropout on every call whenever
        # its rate is > 0.
        dis = self.dis_drop(self.distant_mat)
        adj = torch.softmax(torch.relu(dis), dim=1)  # (n, n)

        x_t = x.permute(0, 2, 1, 3)  # (b, t, n, k)
        tokens = x_t.reshape(-1, k)
        z = self.w_2(tokens).reshape(-1, t, n, self.d_out)
        for w in self.w_stack:
            z = torch.matmul(adj, z) + w(tokens).reshape(-1, t, n, self.d_out)
        z = z.permute(0, 2, 1, 3).contiguous()

        if self.no:
            z = self.norm(z + x)
        return z


class STAttnGraphConv_verO(nn.Module):
    """Attention-built graph convolution plus a parallel pass over the raw distance matrix.

    For each sample, head-attention weights derived from the input mix the
    per-head node embedding banks ``V_T`` and ``V_S`` into ``e`` of shape
    (b, n_route, d_c). The two branches are:

    * Branch one propagates over ``softmax(relu(dis_mat + e @ e^T), dim=2)``.
    * Branch two propagates over ``dis_mat`` as given.

    Both branches start from ``w_2`` and their outputs are summed.
    """

    def __init__(self, n_route, n_his, d_attribute, d_out, dis_mat, n_head=4, d_q=32, d_k=64, d_c=10, kt=2, normal=False) -> None:
        super().__init__()
        # Per-head keys and node-embedding banks, created and Xavier-initialised
        # in this exact order.
        banks = (
            ("K_S", (n_head, d_k)),
            ("V_S", (n_head, n_route, d_c)),
            ("K_T", (n_head, d_k)),
            ("V_T", (n_head, n_route, d_c)),
        )
        for name, shape in banks:
            param = nn.Parameter(torch.empty(*shape))
            nn.init.xavier_uniform_(param.data)
            setattr(self, name, param)

        # Q_0 projects every token. Q_S reads the time-averaged per-node
        # queries; Q_T reads the node-averaged per-step queries.
        self.Q_0 = nn.Linear(d_attribute, d_q, bias=False)
        self.Q_S = nn.Linear(n_route*d_q, d_k, bias=False)
        self.Q_T = nn.Linear(n_his*d_q, d_k, bias=False)

        self.d_q, self.d_k, self.d_c = d_q, d_k, d_c
        self.n_route = n_route
        self.d_out, self.d_attribute = d_out, d_attribute
        self.w_2 = nn.Linear(d_attribute, d_out, bias=False)
        self.w_stack = nn.ModuleList(
            [nn.Linear(d_attribute, d_out, bias=False) for _ in range(kt)]
        )
        self.w2_stack = nn.ModuleList(
            [nn.Linear(d_attribute, d_out, bias=False) for _ in range(kt)]
        )

        self.distant_mat = dis_mat

        self.no = normal
        if normal:
            # Normalises (z + x), so d_out must equal d_attribute.
            self.norm = nn.LayerNorm(d_attribute, eps=1e-6)

    @staticmethod
    def _mix_heads(bank, weights):
        """Combine a head bank (n_head, n, d_c) with weights (b, n_head) into (b, n, d_c)."""
        lhs = bank.permute(1, 2, 0).unsqueeze(2).unsqueeze(0)  # (1, n, d_c, 1, h)
        rhs = weights.unsqueeze(2).unsqueeze(1).unsqueeze(1)  # (b, 1, 1, h, 1)
        return torch.matmul(lhs, rhs).squeeze(-1).squeeze(-1)

    def _head_weights(self, query, keys):
        """Scaled dot-product softmax of queries (b, d_k) against keys (n_head, d_k)."""
        logits = torch.matmul(query / self.d_k**0.5, keys.transpose(0, 1))
        return torch.softmax(logits, dim=1)

    def _propagate(self, adj, seq, layers, t, n):
        """Run ``w_2``, then ``len(layers)`` rounds of ``adj @ h + layer(seq)``.

        ``seq`` is (b, t, n, k); the result is returned as (b, n, t, d_out).
        """
        tokens = seq.reshape(-1, seq.size(-1))
        h = self.w_2(tokens).reshape(-1, t, n, self.d_out)
        for layer in layers:
            h = torch.matmul(adj, h) + layer(tokens).reshape(-1, t, n, self.d_out)
        return h.permute(0, 2, 1, 3).contiguous()

    def forward(self, x):
        """Apply the layer to ``x`` of shape (b, n, t, k).

        Returns:
            Tensor (b, n, t, d_out), or ``LayerNorm(z + x)`` when ``normal``.
        """
        b, n, t, k = x.size()

        q = self.Q_0(x.reshape(-1, k)).reshape(b, n, t, self.d_q)
        q_time = self.Q_T(q.mean(dim=1).reshape(b, t*self.d_q))
        q_space = self.Q_S(q.mean(dim=2).reshape(b, n*self.d_q))

        emb_t = self._mix_heads(self.V_T, self._head_weights(q_time, self.K_T))
        emb_s = self._mix_heads(self.V_S, self._head_weights(q_space, self.K_S))
        emb = emb_t + emb_s  # (b, n, d_c)

        learned = torch.matmul(emb, emb.transpose(1, 2))
        adj = torch.softmax(torch.relu(self.distant_mat + learned), dim=2).unsqueeze(1)

        seq = x.permute(0, 2, 1, 3)
        out = self._propagate(adj, seq, self.w_stack, t, n)
        out = out + self._propagate(self.distant_mat, seq, self.w2_stack, t, n)

        if self.no:
            out = self.norm(out + x)
        return out


def laplacian(W):
    """Return the symmetrically normalised adjacency ``D^-1/2 (W + I) D^-1/2``.

    ``D`` is the diagonal matrix of row sums of ``W + I``. The added self-loops
    keep every degree >= 1 when ``W`` is non-negative, so the inverse square
    root stays finite.

    Args:
        W: square 2-D tensor (N, N).

    Returns:
        Tensor (N, N). The identity is created in the default floating dtype,
        so integer inputs are promoted exactly as with ``torch.eye(N)``.

    Raises:
        ValueError: if ``W`` is not a square 2-D tensor.
    """
    if W.dim() != 2 or W.shape[0] != W.shape[1]:
        raise ValueError(
            f"laplacian expects a square 2-D tensor, got shape {tuple(W.shape)}")
    n = W.shape[0]
    # Allocate the identity directly on W's device instead of CPU + transfer.
    W = W + torch.eye(n, device=W.device)
    d_inv_sqrt = W.sum(dim=1) ** (-0.5)
    # Scaling rows, then columns, by broadcasting equals D @ W @ D for a
    # diagonal D, without two dense O(N^3) matmuls.
    return d_inv_sqrt.unsqueeze(1) * W * d_inv_sqrt.unsqueeze(0)


def matrix_fnorm(W):
    """Per-matrix Frobenius norm of a stack, divided by ``sqrt`` of its last dim.

    Args:
        W: 3-D tensor (h, rows, cols), a stack of ``h`` matrices.

    Returns:
        Tensor (h,) holding ``||W[i]||_F / sqrt(cols)``.
    """
    _, _, width = W.shape  # also rejects anything that is not 3-D
    row_sums = torch.pow(W, 2).sum(dim=1)
    norms = torch.pow(row_sums.sum(dim=1), 0.5)
    return norms / (width ** 0.5)


class STAttnGraphConv_drop(nn.Module):
    """Graph convolution over a distance matrix plus a learned low-rank term.

    The adjacency is ``softmax(relu(dis_mat + r1 @ r2.T) / temperature, dim=1)``.
    Dropout with rate ``droprate`` is applied to it in training mode only. It
    is then used for ``kt`` propagation rounds, each adding a fresh linear
    projection of the input.
    """

    def __init__(self, d_attribute, d_out, n_route, dis_mat, kt=2, droprate=0., temperature=1.0) -> None:
        super(STAttnGraphConv_drop, self).__init__()
        self.droprate = droprate
        self.w_2 = nn.Linear(d_attribute, d_out, bias=False)
        self.w_stack = nn.ModuleList([
            nn.Linear(d_attribute, d_out, bias=False) for _ in range(kt)
        ])
        # Low-rank (rank 10) learned correction to the distance matrix.
        self.r1 = nn.Parameter(torch.randn(n_route, 10))
        self.r2 = nn.Parameter(torch.randn(n_route, 10))
        nn.init.xavier_uniform_(self.r1.data)
        nn.init.xavier_uniform_(self.r2.data)
        self.temperature = temperature
        self.d_out = d_out
        self.distant_mat = dis_mat

    def forward(self, x):
        """Apply the layer to ``x`` of shape (b, n, t, k); returns (b, n, t, d_out)."""
        b, n, t, k = x.size()

        logits = torch.relu(self.distant_mat + self.r1 @ self.r2.T) / self.temperature
        adj = torch.softmax(logits, dim=1)  # (n, n)
        # F.dropout defaults to training=True; pass the module mode so the
        # adjacency is not randomly dropped at inference time.
        adj = F.dropout(adj, p=self.droprate, training=self.training)

        x = x.permute(0, 2, 1, 3)  # (b, t, n, k)
        tokens = x.reshape(-1, k)
        z = self.w_2(tokens).reshape(-1, t, n, self.d_out)
        for w in self.w_stack:
            z = torch.matmul(adj, z) + w(tokens).reshape(-1, t, n, self.d_out)
        z = z.permute(0, 2, 1, 3).contiguous()

        return z

# ======================================TPGNN==================================

# class STAttnGraphConv_stamp(nn.Module):
#     # lap+outerwave
#     def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, cor_mat, kt=2, droprate=0., temperature=1.0) -> None:
#         super(STAttnGraphConv_stamp, self).__init__()
#         self.droprate = droprate
#         self.r1 = nn.Parameter(torch.randn(n_route, 10))
#         # self.r2 = nn.Parameter(torch.randn(n_route, 10))
#         self.w_stack = nn.Parameter(torch.randn(kt+1, d_attribute, d_out))
#         nn.init.xavier_uniform_(self.w_stack.data)
#         self.reduce_stamp = nn.Linear(n_his, 1, bias=False)
#         self.temp_1 = nn.Linear(d_attribute//4, kt+1)
#         # self.temp_2 = nn.Linear(d_attribute//4, kt+1)
#         self.temperature = temperature
#         self.d_out = d_out
#         self.distant_mat = dis_mat
#         self.cor_mat = cor_mat
#         self.kt = kt

#     def forward(self, x, stamp):
#         # x:(b,n,t,k) stamp:(b,t,k)
#         residual = x
#         b, n, t, k = x.size()
#         h, _, _ = self.w_stack.shape
#         w_stack = self.w_stack/(matrix_fnorm(self.w_stack).reshape(h, 1, 1))
#         # print(stamp.shape)
#         # (b,t,k)->(b,k,1)->(b,kt+1)
#         period_emb = self.reduce_stamp(stamp.permute(0, 2, 1)).squeeze(2)
#         temp_1 = self.temp_1(period_emb)
#         # temp_2 = self.temp_2(period_emb)
#         adj = self.distant_mat.clone()
#         # adj_2 = self.cor_mat.clone()
#         if self.training:
#             nonzero_mask = self.distant_mat > 0
#             adj[nonzero_mask] = F.dropout(adj[nonzero_mask], p=self.droprate)
#             # nonzero_mask = self.cor_mat.abs() > 0
#             # adj_2[nonzero_mask] = F.dropout(
#             #     adj_2[nonzero_mask], p=self.droprate)

#         adj_1 = laplacian(adj)
#         adj_2 = torch.softmax(torch.relu(
#             self.r1@self.r1.T)/self.temperature, dim=0)
#         adj_1 = F.dropout(adj_1, p=self.droprate)
#         adj_2 = F.dropout(adj_2, p=self.droprate)
#         # (b,n,t,k)->(b,t,n,k)
#         z = (x@w_stack[0])*(temp_1[:, 0].reshape(b, 1, 1, 1))
#         z = z.permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         # for i in range(1, self.kt):
#         #     z = adj@z + \
#         #         (x@w_stack[i]*(temp[:, i].reshape(b, 1, 1, 1))
#         #          ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         for i in range(1, self.kt):
#             z = adj_1@z + \
#                 (x@w_stack[i]*(temp_1[:, i].reshape(b, 1, 1, 1))
#                  ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         z_fix = (x@w_stack[0])*(temp_1[:, 0].reshape(b, 1, 1, 1))
#         z_fix = z_fix.permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         for i in range(1, self.kt):
#             z_fix = adj_2@z_fix + \
#                 (x@w_stack[i]*(temp_1[:, i].reshape(b, 1, 1, 1))
#                  ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         z = z+z_fix

#         z = z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)

#         return z


class TPGNN(nn.Module):
    """Graph convolution over a distance graph and a learned graph, with stamp-driven hop weights.

    Each call builds two (n_route, n_route) adjacencies. Both are
    column-normalised via ``softmax(relu(.) / temperature, dim=0)``:

    * ``adj_1`` comes from ``laplacian(dis_mat)``. In training mode, dropout is
      first applied to the strictly positive entries of ``dis_mat``.
    * ``adj_2`` comes from ``r1 @ r1.T`` with learned node factors ``r1``
      of shape (n_route, n_c).

    Hop ``i`` contributes ``x @ w_stack[i]`` scaled by a per-sample coefficient
    ``temp_1[:, i]`` predicted from the timestamp features. Each matrix in
    ``w_stack`` is first divided by its ``matrix_fnorm``. Both graphs propagate
    the same hop terms, and their results are summed.

    ``dis_mat`` is stored as a plain attribute, not a registered buffer, so
    ``.to()`` does not move it and ``state_dict`` does not save it.
    """
    # softlap+outerwave

    def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, kt=2, n_c=10, droprate=0., temperature=1.0) -> None:
        super(TPGNN, self).__init__()
        self.droprate = droprate
        self.r1 = nn.Parameter(torch.randn(n_route, n_c))
        self.w_stack = nn.Parameter(torch.randn(kt+1, d_attribute, d_out))
        nn.init.xavier_uniform_(self.w_stack.data)
        # Pools the n_his time axis of the stamp features.
        self.reduce_stamp = nn.Linear(n_his, 1, bias=False)
        # One coefficient per hop from d_attribute // 4 stamp channels.
        self.temp_1 = nn.Linear(d_attribute//4, kt+1)
        self.temperature = temperature
        self.d_out = d_out
        self.distant_mat = dis_mat

        self.kt = kt

    def forward(self, x, stamp):
        """Apply the layer.

        Args:
            x: tensor (b, n, t, k) with ``k == d_attribute`` and ``n == n_route``.
            stamp: timestamp features (b, n_his, d_attribute // 4).

        Returns:
            Tensor (b, n, t, d_out), returned as a permuted (non-contiguous) view.
        """
        b, n, t, k = x.size()
        h = self.w_stack.shape[0]
        w_stack = self.w_stack / matrix_fnorm(self.w_stack).reshape(h, 1, 1)

        # (b, n_his, c) -> (b, c, n_his) -> (b, c, 1) -> (b, c) -> (b, kt+1)
        period_emb = self.reduce_stamp(stamp.permute(0, 2, 1)).squeeze(2)
        temp_1 = self.temp_1(period_emb)

        adj = self.distant_mat.clone()
        if self.training:
            # Edge dropout, restricted to the positive-weight entries.
            nonzero_mask = self.distant_mat > 0
            adj[nonzero_mask] = F.dropout(adj[nonzero_mask], p=self.droprate)

        adj_1 = torch.softmax(torch.relu(laplacian(adj)) / self.temperature, dim=0)
        adj_2 = torch.softmax(torch.relu(self.r1 @ self.r1.T) / self.temperature, dim=0)
        # F.dropout defaults to training=True; pass the module mode so the
        # adjacencies are not randomly dropped at inference time.
        adj_1 = F.dropout(adj_1, p=self.droprate, training=self.training)
        adj_2 = F.dropout(adj_2, p=self.droprate, training=self.training)

        def hop_input(i):
            # x @ W_i scaled per sample, laid out as (b*t, n, d_out).
            scaled = x @ w_stack[i] * temp_1[:, i].reshape(b, 1, 1, 1)
            return scaled.permute(0, 2, 1, 3).reshape(b * t, n, -1)

        # Both branches start from, and add, the same hop terms.
        # NOTE(review): hops run over i = 1..kt-1, so w_stack[kt] and
        # temp_1[:, kt] are never used; confirm whether range(1, kt + 1)
        # was intended.
        z = z_fix = hop_input(0)
        for i in range(1, self.kt):
            term = hop_input(i)
            z = adj_1 @ z + term
            z_fix = adj_2 @ z_fix + term
        z = z + z_fix

        return z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)


class TPGNN_wostamp(nn.Module):
    """TPGNN ablation whose hop coefficients are free parameters, not predicted from stamps.

    Identical to ``TPGNN`` except for the hop coefficients. Here ``temp_1`` is
    a learned vector of ``kt + 1`` scalars shared by all samples, and
    ``stamp`` is ignored. ``reduce_stamp`` is constructed but unused; it is
    kept so parameter names match the full model.
    """
    # softlap+outerwave

    def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, kt=2, n_c=10, droprate=0., temperature=1.0) -> None:
        super(TPGNN_wostamp, self).__init__()
        self.droprate = droprate
        self.r1 = nn.Parameter(torch.randn(n_route, n_c))
        self.w_stack = nn.Parameter(torch.randn(kt+1, d_attribute, d_out))
        nn.init.xavier_uniform_(self.w_stack.data)
        self.reduce_stamp = nn.Linear(n_his, 1, bias=False)  # unused in forward
        self.temp_1 = nn.Parameter(torch.randn(kt+1))  # one scalar per hop
        self.temperature = temperature
        self.d_out = d_out
        self.distant_mat = dis_mat

        self.kt = kt

    def forward(self, x, stamp):
        """Apply the layer.

        Args:
            x: tensor (b, n, t, k) with ``k == d_attribute`` and ``n == n_route``.
            stamp: ignored; accepted for interface compatibility with ``TPGNN``.

        Returns:
            Tensor (b, n, t, d_out), returned as a permuted (non-contiguous) view.
        """
        b, n, t, k = x.size()
        h = self.w_stack.shape[0]
        w_stack = self.w_stack / matrix_fnorm(self.w_stack).reshape(h, 1, 1)
        temp_1 = self.temp_1

        adj = self.distant_mat.clone()
        if self.training:
            # Edge dropout, restricted to the positive-weight entries.
            nonzero_mask = self.distant_mat > 0
            adj[nonzero_mask] = F.dropout(adj[nonzero_mask], p=self.droprate)

        adj_1 = torch.softmax(torch.relu(laplacian(adj)) / self.temperature, dim=0)
        adj_2 = torch.softmax(torch.relu(self.r1 @ self.r1.T) / self.temperature, dim=0)
        # F.dropout defaults to training=True; pass the module mode so the
        # adjacencies are not randomly dropped at inference time.
        adj_1 = F.dropout(adj_1, p=self.droprate, training=self.training)
        adj_2 = F.dropout(adj_2, p=self.droprate, training=self.training)

        def hop_input(i):
            # x @ W_i scaled by the shared hop scalar, laid out as (b*t, n, d_out).
            scaled = x @ w_stack[i] * temp_1[i]
            return scaled.permute(0, 2, 1, 3).reshape(b * t, n, -1)

        # Both branches start from, and add, the same hop terms.
        z = z_fix = hop_input(0)
        for i in range(1, self.kt):
            term = hop_input(i)
            z = adj_1 @ z + term
            z_fix = adj_2 @ z_fix + term
        z = z + z_fix

        return z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)


class TPGNN_wooverview(nn.Module):
    """TPGNN ablation whose hop coefficients vary per time step instead of per sample.

    Identical to ``TPGNN`` except for the hop coefficients. ``temp_1`` is
    applied to every stamp time step directly, without the ``reduce_stamp``
    pooling, so hop ``i`` is scaled by ``temp_1(stamp)[:, :, i]`` of shape
    (b, t). ``reduce_stamp`` is constructed but unused; it is kept so parameter
    names match the full model.
    """
    # softlap+outerwave

    def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, kt=2, n_c=10, droprate=0., temperature=1.0) -> None:
        super(TPGNN_wooverview, self).__init__()
        self.droprate = droprate
        self.r1 = nn.Parameter(torch.randn(n_route, n_c))
        self.w_stack = nn.Parameter(torch.randn(kt+1, d_attribute, d_out))
        nn.init.xavier_uniform_(self.w_stack.data)
        self.reduce_stamp = nn.Linear(n_his, 1, bias=False)  # unused in forward
        self.temp_1 = nn.Linear(d_attribute//4, kt+1)
        self.temperature = temperature
        self.d_out = d_out
        self.distant_mat = dis_mat

        self.kt = kt

    def forward(self, x, stamp):
        """Apply the layer.

        Args:
            x: tensor (b, n, t, k) with ``k == d_attribute`` and ``n == n_route``.
            stamp: timestamp features (b, t, d_attribute // 4), with the same
                ``t`` as ``x``.

        Returns:
            Tensor (b, n, t, d_out), returned as a permuted (non-contiguous) view.
        """
        b, n, t, k = x.size()
        h = self.w_stack.shape[0]
        w_stack = self.w_stack / matrix_fnorm(self.w_stack).reshape(h, 1, 1)
        temp_1 = self.temp_1(stamp)  # (b, t, kt+1)

        adj = self.distant_mat.clone()
        if self.training:
            # Edge dropout, restricted to the positive-weight entries.
            nonzero_mask = self.distant_mat > 0
            adj[nonzero_mask] = F.dropout(adj[nonzero_mask], p=self.droprate)

        adj_1 = torch.softmax(torch.relu(laplacian(adj)) / self.temperature, dim=0)
        adj_2 = torch.softmax(torch.relu(self.r1 @ self.r1.T) / self.temperature, dim=0)
        # F.dropout defaults to training=True; pass the module mode so the
        # adjacencies are not randomly dropped at inference time.
        adj_1 = F.dropout(adj_1, p=self.droprate, training=self.training)
        adj_2 = F.dropout(adj_2, p=self.droprate, training=self.training)

        def hop_input(i):
            # x @ W_i scaled per (sample, step), laid out as (b*t, n, d_out).
            scaled = x @ w_stack[i] * temp_1[:, :, i].reshape(b, 1, t, 1)
            return scaled.permute(0, 2, 1, 3).reshape(b * t, n, -1)

        # Both branches start from, and add, the same hop terms.
        z = z_fix = hop_input(0)
        for i in range(1, self.kt):
            term = hop_input(i)
            z = adj_1 @ z + term
            z_fix = adj_2 @ z_fix + term
        z = z + z_fix

        return z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)


class TPGNN_wonormalize(nn.Module):
    """TPGNN ablation that uses the hop weight matrices without norm scaling.

    Identical to ``TPGNN`` except that ``w_stack`` is used as-is. It is not
    divided by ``matrix_fnorm``.
    """
    # softlap+outerwave

    def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, kt=2, n_c=10, droprate=0., temperature=1.0) -> None:
        super(TPGNN_wonormalize, self).__init__()
        self.droprate = droprate
        self.r1 = nn.Parameter(torch.randn(n_route, n_c))
        self.w_stack = nn.Parameter(torch.randn(kt+1, d_attribute, d_out))
        nn.init.xavier_uniform_(self.w_stack.data)
        # Pools the n_his time axis of the stamp features.
        self.reduce_stamp = nn.Linear(n_his, 1, bias=False)
        self.temp_1 = nn.Linear(d_attribute//4, kt+1)
        self.temperature = temperature
        self.d_out = d_out
        self.distant_mat = dis_mat

        self.kt = kt

    def forward(self, x, stamp):
        """Apply the layer.

        Args:
            x: tensor (b, n, t, k) with ``k == d_attribute`` and ``n == n_route``.
            stamp: timestamp features (b, n_his, d_attribute // 4).

        Returns:
            Tensor (b, n, t, d_out), returned as a permuted (non-contiguous) view.
        """
        b, n, t, k = x.size()
        w_stack = self.w_stack  # deliberately not normalised in this ablation

        # (b, n_his, c) -> (b, c, n_his) -> (b, c, 1) -> (b, c) -> (b, kt+1)
        period_emb = self.reduce_stamp(stamp.permute(0, 2, 1)).squeeze(2)
        temp_1 = self.temp_1(period_emb)

        adj = self.distant_mat.clone()
        if self.training:
            # Edge dropout, restricted to the positive-weight entries.
            nonzero_mask = self.distant_mat > 0
            adj[nonzero_mask] = F.dropout(adj[nonzero_mask], p=self.droprate)

        adj_1 = torch.softmax(torch.relu(laplacian(adj)) / self.temperature, dim=0)
        adj_2 = torch.softmax(torch.relu(self.r1 @ self.r1.T) / self.temperature, dim=0)
        # F.dropout defaults to training=True; pass the module mode so the
        # adjacencies are not randomly dropped at inference time.
        adj_1 = F.dropout(adj_1, p=self.droprate, training=self.training)
        adj_2 = F.dropout(adj_2, p=self.droprate, training=self.training)

        def hop_input(i):
            # x @ W_i scaled per sample, laid out as (b*t, n, d_out).
            scaled = x @ w_stack[i] * temp_1[:, i].reshape(b, 1, 1, 1)
            return scaled.permute(0, 2, 1, 3).reshape(b * t, n, -1)

        # Both branches start from, and add, the same hop terms.
        z = z_fix = hop_input(0)
        for i in range(1, self.kt):
            term = hop_input(i)
            z = adj_1 @ z + term
            z_fix = adj_2 @ z_fix + term
        z = z + z_fix

        return z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)


class TPGNN_wokmatrices(nn.Module):
    """TPGNN ablation with only two weight matrices instead of one per hop.

    ``w_stack`` holds two matrices. Index 0 is used for the initial term and
    index 1 is shared by every later hop. Both are norm-scaled by
    ``matrix_fnorm``. Otherwise this is identical to ``TPGNN``.
    """
    # softlap+outerwave

    def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, kt=2, n_c=10, droprate=0., temperature=1.0) -> None:
        super(TPGNN_wokmatrices, self).__init__()
        self.droprate = droprate
        self.r1 = nn.Parameter(torch.randn(n_route, n_c))
        self.w_stack = nn.Parameter(torch.randn(2, d_attribute, d_out))
        nn.init.xavier_uniform_(self.w_stack.data)
        # Pools the n_his time axis of the stamp features.
        self.reduce_stamp = nn.Linear(n_his, 1, bias=False)
        self.temp_1 = nn.Linear(d_attribute//4, kt+1)
        self.temperature = temperature
        self.d_out = d_out
        self.distant_mat = dis_mat

        self.kt = kt

    def forward(self, x, stamp):
        """Apply the layer.

        Args:
            x: tensor (b, n, t, k) with ``k == d_attribute`` and ``n == n_route``.
            stamp: timestamp features (b, n_his, d_attribute // 4).

        Returns:
            Tensor (b, n, t, d_out), returned as a permuted (non-contiguous) view.
        """
        b, n, t, k = x.size()
        h = self.w_stack.shape[0]
        w_stack = self.w_stack / matrix_fnorm(self.w_stack).reshape(h, 1, 1)

        # (b, n_his, c) -> (b, c, n_his) -> (b, c, 1) -> (b, c) -> (b, kt+1)
        period_emb = self.reduce_stamp(stamp.permute(0, 2, 1)).squeeze(2)
        temp_1 = self.temp_1(period_emb)

        adj = self.distant_mat.clone()
        if self.training:
            # Edge dropout, restricted to the positive-weight entries.
            nonzero_mask = self.distant_mat > 0
            adj[nonzero_mask] = F.dropout(adj[nonzero_mask], p=self.droprate)

        adj_1 = torch.softmax(torch.relu(laplacian(adj)) / self.temperature, dim=0)
        adj_2 = torch.softmax(torch.relu(self.r1 @ self.r1.T) / self.temperature, dim=0)
        # F.dropout defaults to training=True; pass the module mode so the
        # adjacencies are not randomly dropped at inference time.
        adj_1 = F.dropout(adj_1, p=self.droprate, training=self.training)
        adj_2 = F.dropout(adj_2, p=self.droprate, training=self.training)

        def hop_input(i):
            # Hop 0 uses the first matrix; every later hop shares the second.
            # Both branches use the *normalised* w_stack (the adj_1 branch
            # previously read the un-normalised self.w_stack[1]).
            weight = w_stack[min(i, 1)]
            scaled = x @ weight * temp_1[:, i].reshape(b, 1, 1, 1)
            return scaled.permute(0, 2, 1, 3).reshape(b * t, n, -1)

        # Both branches start from, and add, the same hop terms.
        z = z_fix = hop_input(0)
        for i in range(1, self.kt):
            term = hop_input(i)
            z = adj_1 @ z + term
            z_fix = adj_2 @ z_fix + term
        z = z + z_fix

        return z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)


class TPGablation:
    """Factory that instantiates a TPGNN ablation variant by its string name."""

    def __init__(self) -> None:
        # Registry of available variants, keyed by the name callers pass to `get`.
        registry = {}
        registry["TPGNN"] = TPGNN
        registry["TPGNN_wostamp"] = TPGNN_wostamp
        registry["TPGNN_wooverview"] = TPGNN_wooverview
        registry["TPGNN_wonormalize"] = TPGNN_wonormalize
        registry["TPGNN_wokmatrices"] = TPGNN_wokmatrices
        self.module_dict = registry

    def get(self, modulename, d_attribute, d_out, n_route, n_his, dis_mat, kt=2, n_c=10, droprate=0., temperature=1.0):
        """Build the variant registered under ``modulename``.

        All hyperparameters are forwarded positionally to the selected class,
        in the order (d_attribute, d_out, n_route, n_his, dis_mat, kt, n_c,
        droprate, temperature).

        Raises:
            KeyError: if ``modulename`` is not a registered variant.
        """
        # Resolve the class first so an unknown name fails before anything else.
        builder = self.module_dict[modulename]
        ctor_args = (d_attribute, d_out, n_route, n_his, dis_mat,
                     kt, n_c, droprate, temperature)
        return builder(*ctor_args)
# class STAttnGraphConv_stamp(nn.Module):
#     # softlap+softcor
#     def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, cor_mat, kt=2, droprate=0., temperature=1.0) -> None:
#         super(STAttnGraphConv_stamp, self).__init__()
#         self.droprate = droprate
#         # self.r1 = nn.Parameter(torch.randn(n_route, 10))
#         # self.r2 = nn.Parameter(torch.randn(n_route, 10))
#         self.w_stack = nn.Parameter(torch.randn(kt+1, d_attribute, d_out))
#         nn.init.xavier_uniform_(self.w_stack.data)
#         self.reduce_stamp = nn.Linear(n_his, 1, bias=False)
#         self.temp_1 = nn.Linear(d_attribute//4, kt+1)
#         self.temp_2 = nn.Linear(d_attribute//4, kt+1)
#         self.temperature = temperature
#         self.d_out = d_out
#         self.distant_mat = dis_mat
#         self.cor_mat = cor_mat
#         self.kt = kt

#     def forward(self, x, stamp):
#         # x:(b,n,t,k) stamp:(b,t,k)
#         residual = x
#         b, n, t, k = x.size()
#         h, _, _ = self.w_stack.shape
#         w_stack = self.w_stack/(matrix_fnorm(self.w_stack).reshape(h, 1, 1))
#         # print(stamp.shape)
#         # (b,t,k)->(b,k,1)->(b,kt+1)
#         period_emb = self.reduce_stamp(stamp.permute(0, 2, 1)).squeeze(2)
#         temp_1 = self.temp_1(period_emb)
#         temp_2 = self.temp_2(period_emb)
#         adj_1 = self.distant_mat.clone()
#         adj_2 = self.cor_mat.clone()
#         if self.training:
#             nonzero_mask = self.distant_mat > 0
#             adj_1[nonzero_mask] = F.dropout(
#                 adj_1[nonzero_mask], p=self.droprate)
#             nonzero_mask = self.cor_mat.abs() > 0
#             adj_2[nonzero_mask] = F.dropout(
#                 adj_2[nonzero_mask], p=self.droprate)

#         adj_1 = torch.softmax(torch.relu(
#             laplacian(adj_1))/self.temperature, dim=0)
#         adj_2 = torch.softmax(torch.relu(laplacian(adj_2)) /
#                               self.temperature, dim=0)
#         adj_1 = F.dropout(adj_1, p=self.droprate)
#         adj_2 = F.dropout(adj_2, p=self.droprate)
#         # (b,n,t,k)->(b,t,n,k)
#         z = (x@w_stack[0])*(temp_1[:, 0].reshape(b, 1, 1, 1))
#         z = z.permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         # for i in range(1, self.kt):
#         #     z = adj@z + \
#         #         (x@w_stack[i]*(temp[:, i].reshape(b, 1, 1, 1))
#         #          ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         for i in range(1, self.kt):
#             z = adj_1@z + \
#                 (x@w_stack[i]*(temp_1[:, i].reshape(b, 1, 1, 1))
#                  ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         z_fix = (x@w_stack[0])*(temp_2[:, 0].reshape(b, 1, 1, 1))
#         z_fix = z_fix.permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         for i in range(1, self.kt):
#             z_fix = adj_2@z_fix + \
#                 (x@w_stack[i]*(temp_2[:, i].reshape(b, 1, 1, 1))
#                  ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         z = z+z_fix

#         z = z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)

#         return z


# class STAttnGraphConv_stamp(nn.Module):
#     # softlap+softcor+nonormalized
#     def __init__(self, d_attribute, d_out, n_route, n_his, dis_mat, cor_mat, kt=2, droprate=0., temperature=1.0) -> None:
#         super(STAttnGraphConv_stamp, self).__init__()
#         self.droprate = droprate
#         # self.r1 = nn.Parameter(torch.randn(n_route, 10))
#         # self.r2 = nn.Parameter(torch.randn(n_route, 10))
#         self.w_stack = nn.Parameter(torch.randn(kt+1, d_attribute, d_out))
#         nn.init.xavier_uniform_(self.w_stack.data)
#         self.reduce_stamp = nn.Linear(n_his, 1, bias=False)
#         self.temp_1 = nn.Linear(d_attribute//4, kt+1)
#         self.temp_2 = nn.Linear(d_attribute//4, kt+1)
#         self.temperature = temperature
#         self.d_out = d_out
#         self.distant_mat = dis_mat
#         self.cor_mat = cor_mat
#         self.kt = kt

#     def forward(self, x, stamp):
#         # x:(b,n,t,k) stamp:(b,t,k)
#         residual = x
#         b, n, t, k = x.size()
#         h, _, _ = self.w_stack.shape
#         w_stack = self.w_stack/(matrix_fnorm(self.w_stack).reshape(h, 1, 1))
#         # print(stamp.shape)
#         # (b,t,k)->(b,k,1)->(b,kt+1)
#         period_emb = self.reduce_stamp(stamp.permute(0, 2, 1)).squeeze(2)
#         temp_1 = self.temp_1(period_emb)
#         temp_2 = self.temp_2(period_emb)
#         adj_1 = self.distant_mat.clone()
#         adj_2 = self.cor_mat.clone()
#         if self.training:
#             nonzero_mask = self.distant_mat > 0
#             adj_1[nonzero_mask] = F.dropout(
#                 adj_1[nonzero_mask], p=self.droprate)
#             nonzero_mask = self.cor_mat.abs() > 0
#             adj_2[nonzero_mask] = F.dropout(
#                 adj_2[nonzero_mask], p=self.droprate)

#         adj_1 = torch.softmax(torch.relu(
#             adj_1+torch.eye(adj_1.shape[0]).cuda())/self.temperature, dim=0)
#         adj_2 = torch.softmax(torch.relu(adj_2+torch.eye(adj_2.shape[0]).cuda()) /
#                               self.temperature, dim=0)
#         adj_1 = F.dropout(adj_1, p=self.droprate)
#         adj_2 = F.dropout(adj_2, p=self.droprate)
#         # (b,n,t,k)->(b,t,n,k)
#         z = (x@w_stack[0])*(temp_1[:, 0].reshape(b, 1, 1, 1))
#         z = z.permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         # for i in range(1, self.kt):
#         #     z = adj@z + \
#         #         (x@w_stack[i]*(temp[:, i].reshape(b, 1, 1, 1))
#         #          ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         for i in range(1, self.kt):
#             z = adj_1@z + \
#                 (x@w_stack[i]*(temp_1[:, i].reshape(b, 1, 1, 1))
#                  ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         z_fix = (x@w_stack[0])*(temp_2[:, 0].reshape(b, 1, 1, 1))
#         z_fix = z_fix.permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         for i in range(1, self.kt):
#             z_fix = adj_2@z_fix + \
#                 (x@w_stack[i]*(temp_2[:, i].reshape(b, 1, 1, 1))
#                  ).permute(0, 2, 1, 3).reshape(b*t, n, -1)
#         z = z+z_fix

#         z = z.reshape(b, t, n, self.d_out).permute(0, 2, 1, 3)

#         return z
# ====================================TPGNN========================================

class gcnoperation(nn.Module):
    """Single gated graph-convolution step.

    Neighbour features are aggregated with ``adj``. A bias-free linear map
    then produces ``2 * dim_out`` channels. The first ``dim_out`` channels
    are the value path. The next ``dim_out`` channels, after a sigmoid, act
    as a multiplicative gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        # Value and gate channels are produced by one shared projection.
        self.linear = nn.Linear(dim_in, 2 * dim_out, bias=False)

    def forward(self, x, adj):
        """
        x:(b,n,k1)
        adj:(b,n,n)
        ---
        out:(b,n,k2)
        """
        width = self.dim_out
        projected = self.linear(torch.matmul(adj, x))
        # Split channel axis (dim 2) into value [0, width) and gate [width, 2*width).
        value = projected.narrow(2, 0, width)
        gate = projected.narrow(2, width, width)
        return value * gate.sigmoid()


class cascadegraphconv(nn.Module):
    """Cascade of gated graph convolutions followed by max-pooling over layers.

    ``num_kernel`` ``gcnoperation`` layers are applied in sequence. The first
    maps ``dim_in -> dim_out`` and the remaining ``num_kernel - 1`` map
    ``dim_out -> dim_out``. Every layer reuses the same ``adj``. Each
    layer's output is cropped to vertex rows ``[n_vertices, 2 * n_vertices)``.
    The result is the element-wise maximum over layers.
    NOTE(review): the crop suggests the input graph stacks several copies of
    the ``n_vertices`` nodes (e.g. neighbouring time steps) and only the
    middle copy is kept — confirm against the caller.
    """

    def __init__(self, num_kernel, n_vertices, dim_in, dim_out):
        super().__init__()
        self.w_0 = gcnoperation(dim_in, dim_out)
        self.w_list = nn.ModuleList(
            [gcnoperation(dim_out, dim_out) for _ in range(num_kernel - 1)])
        self.n_vertices = n_vertices

    def forward(self, x, adj):
        """
        x: (b,n,k1)
        adj: (b,n,n)
        ---
        out: (b,m,k2) where m is the number of rows in
             [n_vertices, 2*n_vertices) that exist in n (m = n_vertices
             when n >= 2*n_vertices).

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        # Explicit rank check; the original enforced this implicitly (also via
        # ValueError) by unpacking x.size() into three unused names.
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape (b, n, k), got {tuple(x.shape)}")
        lo, hi = self.n_vertices, 2 * self.n_vertices
        x = self.w_0(x, adj)
        # Crop each layer's output before stacking, so the stacked tensor
        # holds only the rows that survive the crop (same values and gradients
        # as cropping after the max).
        cropped = [x[:, lo:hi, :]]
        for w in self.w_list:
            x = w(x, adj)
            cropped.append(x[:, lo:hi, :])
        out, _ = torch.max(torch.stack(cropped, 0), dim=0)
        return out


if __name__ == "__main__":
    # Smoke test: build one ablation variant on random inputs and print the
    # output shape. Positional args follow the order used by TPGablation.get:
    # d_attribute=10, d_out=10, n_route=8, n_his=12, then the distance matrix.
    # Random draws happen in the same order as before: distance matrix, then
    # model construction, then inputs, then stamps.
    dist = torch.randn(8, 8)
    net = TPGNN_wooverview(10, 10, 8, 12, dist, kt=3, n_c=4)
    # presumably (batch, n_route, n_his, d_attribute) — TODO confirm
    inputs = torch.randn(3, 8, 12, 10)
    # presumably (batch, n_his, stamp_dim) — TODO confirm
    stamps = torch.randn(3, 12, 2)
    result = net(inputs, stamps)
    print(result.shape)
